#![allow(clippy::unwrap_used)]
mod codewriter;

use anyhow::Result;
use naga::{Handle, Type, proc::GlobalCtx, valid::Capabilities};
use naga_oil::compose::{
    ComposableModuleDescriptor, Composer, ComposerError, NagaModuleDescriptor,
};
use regex::Regex;
use std::{borrow::Cow, collections::HashMap, io, path::Path, sync::OnceLock};
use thiserror::Error;
use wgpu::naga::{self, common::wgsl::TypeContext};

/// Marker naga_oil places between an item's name and its base32-encoded module name when
/// mangling identifiers. It must match naga_oil's own value for `undecorate_regex` to match.
const DECORATION_PRE: &str = "X_naga_oil_mod_X";
/// Marker naga_oil appends after the base32-encoded module name of a mangled identifier.
const DECORATION_POST: &str = "X";

/// Items collected for one generated `pub mod`.
///
/// In both variants the maps are keyed by item name, and each value holds the generated
/// Rust source lines for that constant or struct.
enum ModuleInfo {
    /// A shared WGSL include registered with the composer. Its maps are filled with the
    /// items whose mangled names point at this module in the shaders composed later.
    Include {
        /// Generated `pub const` lines, keyed by constant name.
        constants: HashMap<String, Vec<String>>,
        /// Generated `#[repr(C)]` struct definitions, keyed by struct name.
        types: HashMap<String, Vec<String>>,
    },
    /// A shader file with exactly one entry point.
    File {
        /// Path of the WGSL source, as passed to `build_modules`.
        path: String,
        /// Workgroup size of the entry point, emitted as `WORKGROUP_SIZE`.
        wg_size: [u32; 3],
        /// Generated `pub const` lines for the shader's own constants.
        constants: HashMap<String, Vec<String>>,
        /// Generated struct definitions for the shader's own structs.
        types: HashMap<String, Vec<String>>,
    },
}

/// Reduce a naga_oil module reference or a file path to the name used for its Rust module.
///
/// Every `"../` sequence and every double quote is removed first. The file stem of what
/// remains is then returned. If there is no stem (e.g. an empty string), the cleaned text
/// itself is returned.
///
/// NOTE(review): no further sanitizing happens, so a stem such as `foo-bar` is returned
/// unchanged even though it is not a valid Rust identifier.
pub fn make_valid_rust_import(value: &str) -> String {
    let cleaned = value.replace("\"../", "").replace('"', "");
    if let Some(stem) = Path::new(&cleaned).file_stem().and_then(|os| os.to_str()) {
        return stem.to_owned();
    }
    cleaned
}

// https://github.com/bevyengine/naga_oil/blob/master/src/compose/mod.rs#L417-L419
/// Decode the unpadded base32 module name that naga_oil embeds in a mangled identifier.
///
/// Panics if `from` is not valid unpadded base32, or if the decoded bytes are not UTF-8.
fn decode(from: &str) -> String {
    let raw_bytes = data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap();
    String::from_utf8(raw_bytes).unwrap()
}

// https://github.com/bevyengine/naga_oil/blob/master/src/compose/mod.rs#L355-L363
/// Lazily compiled regex that matches one naga_oil-mangled identifier.
///
/// Capture groups:
/// 1. an optional leading ANSI escape sequence (`\x1B[<digits><word char>`);
/// 2. the item name;
/// 3. the base32-encoded module name between `DECORATION_PRE` and `DECORATION_POST`.
fn undecorate_regex() -> &'static Regex {
    static UNDECORATE: OnceLock<Regex> = OnceLock::new();

    UNDECORATE.get_or_init(|| {
        let pre = regex_syntax::escape(DECORATION_PRE);
        let post = regex_syntax::escape(DECORATION_POST);
        let pattern = format!(r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}", pre, post);
        Regex::new(&pattern).unwrap()
    })
}

// https://github.com/bevyengine/naga_oil/blob/master/src/compose/mod.rs#L421-L431
/// Rewrite every naga_oil-mangled identifier in `string` as `module::item`.
///
/// `module` is the decoded module name passed through `make_valid_rust_import`. An ANSI
/// escape sequence captured in front of the identifier is kept in front of the rewrite.
/// The input is borrowed unchanged when nothing matches.
fn demangle_str(string: &str) -> Cow<'_, str> {
    let pattern = undecorate_regex();
    pattern.replace_all(string, |caps: &regex::Captures| {
        let escape_prefix = caps.get(1).map_or("", |m| m.as_str());
        let module = make_valid_rust_import(&decode(caps.get(3).unwrap().as_str()));
        let item = caps.get(2).unwrap().as_str();
        format!("{escape_prefix}{module}::{item}")
    })
}

/// Split a possibly mangled identifier into `(module_path, item_name)`.
///
/// The identifier is demangled and then split on `::`. The last segment is the item name.
/// The segments before it, rejoined with `::`, form the module path. The module path is
/// empty for an identifier without a namespace.
fn mod_name_from_mangled(string: &str) -> (String, String) {
    let demangled = demangle_str(string);
    let segments: Vec<&str> = demangled.split("::").collect();
    // `split` always yields at least one segment, so a last element always exists.
    let (item, module_segments) = segments.split_last().unwrap();
    (module_segments.join("::"), (*item).to_owned())
}

/// Map the WGSL type behind `ty` to the Rust type used for a field of a generated
/// `#[repr(C)]` struct.
///
/// Scalars map to themselves and atomics map to their underlying integer. `vecN` and
/// `mat4x4` map to fixed-size arrays. `vec3` types are widened to 4-element arrays
/// (16 bytes), in line with the 16-byte alignment `alignment_of` reports for them.
///
/// NOTE(review): a WGSL `vec3` is 12 bytes with 16-byte alignment, so WGSL can place a
/// following 4-byte scalar in its last 4 bytes. The widened array does not model that.
/// Confirm that struct layouts avoid it.
///
/// # Panics
/// Panics with the WGSL type name if the type has no mapping here.
fn rust_type_name(ty: Handle<naga::Type>, ctx: &GlobalCtx) -> String {
    let wgsl_name = ctx.type_to_string(ty);

    match wgsl_name.as_str() {
        "i32" | "u32" | "f32" => wgsl_name,
        "atomic<u32>" => "u32".to_owned(),
        "atomic<i32>" => "i32".to_owned(),
        "vec2<f32>" => "[f32; 2]".to_owned(),
        "vec2<u32>" => "[u32; 2]".to_owned(),
        "vec2<i32>" => "[i32; 2]".to_owned(),
        "vec3<f32>" => "[f32; 4]".to_owned(),
        "vec3<u32>" => "[u32; 4]".to_owned(),
        "vec3<i32>" => "[i32; 4]".to_owned(),
        "vec4<f32>" => "[f32; 4]".to_owned(),
        "vec4<u32>" => "[u32; 4]".to_owned(),
        "vec4<i32>" => "[i32; 4]".to_owned(),
        "mat4x4<f32>" => "[[f32; 4]; 4]".to_owned(),
        _ => panic!("{}", wgsl_name),
    }
}

/// Byte alignment of the WGSL type behind `ty`, following the WGSL alignment rules.
///
/// Covers every type `rust_type_name` maps, so each struct member that can be emitted
/// can also be aligned. (`vec3<u32>` used to be missing here and made generation panic.)
///
/// # Panics
/// Panics with the WGSL type name if the type is not handled.
fn alignment_of(ty: Handle<Type>, ctx: &GlobalCtx) -> usize {
    let wgsl_name = ctx.type_to_string(ty);

    match wgsl_name.as_str() {
        "i32" | "u32" | "f32" | "atomic<u32>" | "atomic<i32>" => 4,
        "vec2<f32>" | "vec2<u32>" | "vec2<i32>" => 8,
        // vec3 and vec4 both have 16-byte alignment in WGSL, as do the columns of mat4x4<f32>.
        "vec3<f32>" | "vec3<u32>" | "vec3<i32>" | "vec4<f32>" | "vec4<u32>" | "vec4<i32>"
        | "mat4x4<f32>" => 16,
        _ => panic!("{}", wgsl_name),
    }
}

/// Errors returned by `build_modules`.
#[derive(Debug, Error)]
pub enum GenError {
    /// naga_oil reported a composition error. The `String` is the diagnostic rendered by
    /// `ComposerError::emit_to_string`, and it is what the `Display` message shows.
    #[error("Failed to generate shader module.\n{1}")]
    ImportError(#[source] Box<ComposerError>, String),
    /// Reading an input WGSL file or writing the generated Rust file failed.
    #[error("Failed to read/write input files {0}")]
    IoError(#[from] io::Error),
}

pub fn build_modules(paths: &[&str], includes: &[&str], output_path: &str) -> Result<(), GenError> {
    let mut code = codewriter::CodeWriter::new();
    // Make file as generated so cargo fmt skips it.
    code.add_lines(&[
        "// This file is @generated by brush-wgsl from source wgsl files. Do not edit.",
        "#![allow(dead_code, unused_mut, trivial_numeric_casts)]",
        "#![allow(clippy::all)]",
        "#![allow(clippy::derive_partial_eq_without_eq)]", // Not sure why this isn't part of clippy::all
    ]);
    code.add_lines(&[
        "#[rustfmt::skip]",
        "fn create_composer() -> naga_oil::compose::Composer {",
        "let mut composer = naga_oil::compose::Composer::default().with_capabilities(
            wgpu::naga::valid::Capabilities::all()
        );",
    ]);
    let mut composer = Composer::default().with_capabilities(Capabilities::all());
    let mut modules = HashMap::new();
    for include in includes {
        let helper_source = &std::fs::read_to_string(include)?;
        let include_name = make_valid_rust_import(include);
        composer
            .add_composable_module(ComposableModuleDescriptor {
                source: helper_source,
                file_path: include,
                as_name: Some(include_name.clone()),
                ..Default::default()
            })
            .expect("Failed to add module");

        println!("cargo::rerun-if-changed={include}");

        // Get the input path relative to the output directory
        let relative_path = pathdiff::diff_paths(
            Path::new(&include),
            Path::new(&output_path).parent().unwrap(),
        )
        .unwrap();

        let relative_path_str = relative_path.to_string_lossy().replace('\\', "\\\\");

        code.add_lines(&[
            "composer.add_composable_module(naga_oil::compose::ComposableModuleDescriptor {",
            &format!("source: include_str!(\"./{relative_path_str}\"),"),
            &format!("file_path: \"{relative_path_str}\","),
            &format!("as_name: Some(\"{include_name}\".to_owned()),"),
            "..Default::default()",
            "}).expect(\"Failed to add module\");",
        ]);

        modules.insert(
            include_name,
            ModuleInfo::Include {
                constants: HashMap::new(),
                types: HashMap::new(),
            },
        );
    }

    code.add_lines(&["composer", "}"]);

    for path in paths {
        println!("cargo::rerun-if-changed={path}");

        let source = &std::fs::read_to_string(path)?;
        let module = match composer.make_naga_module(NagaModuleDescriptor {
            source,
            file_path: path,
            ..Default::default()
        }) {
            Ok(m) => m,
            Err(e) => {
                let str = e.emit_to_string(&composer);
                return Err(GenError::ImportError(Box::new(e), str));
            }
        };

        // get file name as module name
        let entries = &module.entry_points;
        assert!(entries.len() == 1, "Must have 1 entry per file");

        let entry = &entries[0];
        let mod_name = make_valid_rust_import(path);
        let ctx = &module.to_ctx();

        let mut constants = HashMap::new();

        for t in module.constants.iter() {
            let type_and_value = match module.global_expressions[t.1.init] {
                naga::Expression::Literal(literal) => match literal {
                    naga::Literal::F64(v) => Some(format!("f64 = {v} as f64")),
                    naga::Literal::F32(v) => Some(format!("f32 = {v} as f32")),
                    naga::Literal::U32(v) => Some(format!("u32 = {v}")),
                    naga::Literal::I32(v) => Some(format!("i32 = {v}")),
                    naga::Literal::Bool(v) => Some(format!("bool = {v}")),
                    naga::Literal::I64(v) => Some(format!("i64 = {v}")),
                    naga::Literal::U64(v) => Some(format!("u64 = {v}")),
                    naga::Literal::AbstractInt(v) => Some(format!("i64 = {v}")),
                    naga::Literal::AbstractFloat(v) => Some(format!("f64 = {v}")),
                    naga::Literal::F16(v) => Some(format!("f16 = {v}")),
                },
                _ => continue,
            };

            if let Some(type_and_value) = type_and_value
                && let Some(mangled_name) = t.1.name.as_ref()
            {
                let (m, name) = mod_name_from_mangled(mangled_name);
                let constant_str = vec![format!("pub const {name}: {type_and_value};")];

                let map = if m == mod_name || m.is_empty() {
                    &mut constants
                } else {
                    match modules.get_mut(&m).unwrap() {
                        ModuleInfo::Include {
                            constants,
                            types: _,
                        } => constants,
                        ModuleInfo::File { .. } => panic!("Unsupported export type"),
                    }
                };
                map.insert(name.clone(), constant_str);
            }
        }

        let mut types = HashMap::new();

        for t in module.types.iter() {
            if let naga::TypeInner::Struct { members, span: _ } = &t.1.inner {
                if members.is_empty() {
                    continue;
                }

                let mangled_name = t.1.name.as_ref().unwrap();

                // Ignore some builtins.
                if mangled_name.contains("__atomic_compare_exchange_result") {
                    continue;
                }

                let (m, name) = mod_name_from_mangled(mangled_name);

                let max_align = members
                    .iter()
                    .map(|x| alignment_of(x.ty, ctx))
                    .max()
                    .unwrap();

                let mut struct_str = vec![format!("#[repr(C, align({max_align}))]")];
                struct_str.push(
                    "#[derive(bytemuck::Pod, bytemuck::Zeroable, Debug, PartialEq, Clone, Copy)]"
                        .to_owned(),
                );
                struct_str.push(format!("pub struct {name} {{"));
                for member in members {
                    let rust_name = rust_type_name(member.ty, ctx);

                    struct_str.push(
                        format!("    pub {}: {},", member.name.as_ref().unwrap(), rust_name)
                            .to_owned(),
                    );
                }
                struct_str.push("}".to_owned());

                let map = if m == mod_name || m.is_empty() {
                    &mut types
                } else {
                    match modules.get_mut(&m).unwrap() {
                        ModuleInfo::Include {
                            constants: _,
                            types,
                        } => types,
                        ModuleInfo::File { .. } => panic!("Unsupported export type"),
                    }
                };
                map.insert(name.clone(), struct_str);
            }
        }

        modules.insert(
            mod_name,
            ModuleInfo::File {
                path: (*path).to_owned(),
                wg_size: entry.workgroup_size,
                constants,
                types,
            },
        );
    }

    // Make sure output is ordered deterministically.
    let mut mods: Vec<_> = modules.iter().collect();
    mods.sort_by_key(|x| x.0.clone());

    for m in mods {
        match m.1 {
            ModuleInfo::Include { constants, types } => {
                code.add_line("#[rustfmt::skip]");
                code.add_line(format!("pub mod {} {{", m.0));

                let mut writes: Vec<_> = constants.iter().chain(types.iter()).collect();
                writes.sort_by_key(|x| x.0.clone());
                for c in writes {
                    code.add_lines(c.1);
                }
                code.add_line("}");
            }

            ModuleInfo::File {
                path,
                constants,
                types,
                wg_size,
            } => {
                code.add_line("#[rustfmt::skip]");
                code.add_line(format!("pub mod {} {{", m.0));

                let [wg_x, wg_y, wg_z] = wg_size;
                code.add_line(format!(
                    "pub const WORKGROUP_SIZE: [u32; 3] = [{wg_x}, {wg_y}, {wg_z}];"
                ));

                let mut writes: Vec<_> = constants.iter().chain(types.iter()).collect();
                writes.sort_by_key(|x| x.0.clone());
                for c in writes {
                    code.add_lines(c.1);
                }

                // Get the input path relative to the output directory
                let relative_path = pathdiff::diff_paths(
                    Path::new(&path),
                    Path::new(&output_path).parent().unwrap(),
                )
                .unwrap();

                let relative_path_str = relative_path.to_string_lossy().replace('\\', "\\\\");

                code.add_lines(&[
                    "",
                    "pub(crate) fn create_shader_source(",
                    "   shader_defs: std::collections::HashMap<String, naga_oil::compose::ShaderDefValue>",
                    ") -> wgpu::naga::Module {",
                    "super::create_composer().make_naga_module(naga_oil::compose::NagaModuleDescriptor {",
                    &format!("source: include_str!(\"{relative_path_str}\"),"),
                    &format!("file_path: \"{relative_path_str}\","),
                    "shader_defs,",
                    "..Default::default()",
                    "}).expect(\"Failed to add module\")",
                    "}",
                    "}",
                ]);
            }
        }
    }

    std::fs::write(output_path, code.string())?;
    Ok(())
}
